{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CatBoost R Tutorial\n",
    "R kernel for Jupyter Notebook: [link](https://irkernel.github.io/installation/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading required package: lattice\n",
      "Loading required package: ggplot2\n"
     ]
    }
   ],
   "source": [
    "library(catboost)\n",
    "library(caret)\n",
    "library(titanic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make CatBoost Pool\n",
    "\n",
    "### From file\n",
    "\n",
    "Two files are needed to create a CatBoost Pool from a file in R:\n",
    "\n",
    "- File with features\n",
    "  \n",
    "```sh\n",
    "> cat adult_train.1000 | head -1\n",
    "1\t28.0\tPrivate\t120135.0\tAssoc-voc\t11.0\tNever-married\tSales\tNot-in-family\tWhite\tFemale\t0.0\t0.0\t40.0\tUnited-States\n",
    "```\n",
    "\n",
    "- Column description file\n",
    "\n",
    "```sh\n",
    "> cat adult.cd | head -3\n",
    "0\tTarget\n",
    "2\tCateg\n",
    "4\tCateg\n",
    "```\n",
    "\n",
    "Column indices are 0-based. Each column type must be one of:\n",
    "\n",
    "- Target (one column);\n",
    "- Categ;\n",
    "- Num (the default type).\n",
    "\n",
    "Because Num is the default, numeric columns can be omitted from the column description file."
   ]
  },
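  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A column description file is plain tab-separated text, so it can also be written from R. The cell below is a minimal sketch, assuming a hypothetical dataset whose target is in column 0 and whose categorical features are in columns 2 and 4. The name `cd_path` and the temporary file it points to are chosen for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical column description: 0-based column index, a tab, then the column type.\n",
    "cd_lines <- c(\"0\\tTarget\", \"2\\tCateg\", \"4\\tCateg\")\n",
    "cd_path <- tempfile(fileext = \".cd\")  # illustrative temporary path\n",
    "writeLines(cd_lines, cd_path)\n",
    "\n",
    "# Read the file back to check its layout; columns that are not listed default to Num.\n",
    "cd <- read.table(cd_path, sep = \"\\t\", col.names = c(\"index\", \"type\"),\n",
    "                 stringsAsFactors = FALSE)\n",
    "print(cd)"
   ]
  },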
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<tbody>\n",
       "\t<tr><td>1            </td><td>1            </td><td>28           </td><td>3.891958e+36 </td><td>120135       </td><td>-1.040168e-34</td><td>11           </td><td>1.261457e+32 </td><td>-371032621056</td><td>8.078709e-34 </td><td>-9.782155e+30</td><td>-9.047987e-38</td><td>0            </td><td>0            </td><td>40           </td><td>1.219625e+24 </td></tr>\n",
       "</tbody>\n",
       "</table>\n"
      ],
      "text/latex": [
       "\\begin{tabular}{llllllllllllllll}\n",
       "\t 1             & 1             & 28            & 3.891958e+36  & 120135        & -1.040168e-34 & 11            & 1.261457e+32  & -371032621056 & 8.078709e-34  & -9.782155e+30 & -9.047987e-38 & 0             & 0             & 40            & 1.219625e+24 \\\\\n",
       "\\end{tabular}\n"
      ],
      "text/markdown": [
       "\n",
       "| 1             | 1             | 28            | 3.891958e+36  | 120135        | -1.040168e-34 | 11            | 1.261457e+32  | -371032621056 | 8.078709e-34  | -9.782155e+30 | -9.047987e-38 | 0             | 0             | 40            | 1.219625e+24  | \n",
       "\n",
       "\n"
      ],
      "text/plain": [
       "     [,1] [,2] [,3] [,4]         [,5]   [,6]          [,7] [,8]        \n",
       "[1,] 1    1    28   3.891958e+36 120135 -1.040168e-34 11   1.261457e+32\n",
       "     [,9]          [,10]        [,11]         [,12]         [,13] [,14] [,15]\n",
       "[1,] -371032621056 8.078709e-34 -9.782155e+30 -9.047987e-38 0     0     40   \n",
       "     [,16]       \n",
       "[1,] 1.219625e+24"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pool_path = system.file(\"extdata\", \"adult_train.1000\", package = \"catboost\")\n",
    "column_description_path = system.file(\"extdata\", \"adult.cd\", package = \"catboost\")\n",
    "pool <- catboost.load_pool(pool_path, column_description = column_description_path)\n",
    "head(pool, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### From matrix\n",
    "\n",
    "Categorical features must be converted to numeric columns using a method of your choice (e.g. a string hash). Indices in the **`cat_features`** vector are 0-based and refer to columns of the feature matrix, so they can differ from the indices in the **`.cd`** file."
   ]
  },
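  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index bookkeeping is easy to get wrong when the feature columns come from a table that also contains the target. The cell below is a small sketch in base R, assuming a hypothetical table with the target in column 1 and categorical columns at 1-based positions 3 and 5. All variable names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical layout of the full table (1-based R positions).\n",
    "target_col <- 1\n",
    "cat_cols_1based <- c(3, 5)\n",
    "\n",
    "# Dropping the target shifts every later column left by one;\n",
    "# switching from 1-based to 0-based indexing subtracts one more.\n",
    "cat_features_0based <- sapply(cat_cols_1based,\n",
    "                              function(i) i - sum(target_col < i) - 1)\n",
    "print(cat_features_0based)"
   ]
  },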
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<tbody>\n",
       "\t<tr><td>1     </td><td>1     </td><td>28    </td><td>4     </td><td>120135</td><td>9     </td><td>11    </td><td>5     </td><td>12    </td><td>2     </td><td>5     </td><td>1     </td><td>0     </td><td>0     </td><td>40    </td><td>32    </td></tr>\n",
       "</tbody>\n",
       "</table>\n"
      ],
      "text/latex": [
       "\\begin{tabular}{llllllllllllllll}\n",
       "\t 1      & 1      & 28     & 4      & 120135 & 9      & 11     & 5      & 12     & 2      & 5      & 1      & 0      & 0      & 40     & 32    \\\\\n",
       "\\end{tabular}\n"
      ],
      "text/markdown": [
       "\n",
       "| 1      | 1      | 28     | 4      | 120135 | 9      | 11     | 5      | 12     | 2      | 5      | 1      | 0      | 0      | 40     | 32     | \n",
       "\n",
       "\n"
      ],
      "text/plain": [
       "     [,1] [,2] [,3] [,4] [,5]   [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13]\n",
       "[1,] 1    1    28   4    120135 9    11   5    12   2     5     1     0    \n",
       "     [,14] [,15] [,16]\n",
       "[1,] 0     40    32   "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pool_path = system.file(\"extdata\", \"adult_train.1000\", package=\"catboost\")\n",
    "\n",
    "column_description_vector = rep('numeric', 15)\n",
    "cat_features <- c(3, 5, 7, 8, 9, 10, 11, 15)\n",
    "for (i in cat_features)\n",
    "    column_description_vector[i] <- 'factor'\n",
    "\n",
    "data <- read.table(pool_path, header = FALSE, sep = \"\\t\", colClasses = column_description_vector, na.strings = 'NAN')\n",
    "\n",
    "# Transform categorical features to numeric.\n",
    "for (i in cat_features)\n",
    "    data[,i] <- as.numeric(factor(data[,i]))\n",
    "\n",
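    "target <- c(1)\n",
    "# cat_features holds 1-based positions in `data`, which still includes the target in column 1.\n",
    "# Subtract 1 for the dropped target column and 1 more to get 0-based feature indices.\n",
    "pool <- catboost.load_pool(as.matrix(data[,-target]),\n",
    "                             label = as.matrix(data[,target]),\n",
    "                             cat_features = cat_features - 2)\n",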
    "head(pool, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### From data.frame\n",
    "\n",
    "Categorical features must be factors (use `as.factor()`, the `colClasses` argument of `read.table()`, etc.). Numeric features and the target must be of type numeric."
   ]
  },
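  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of these type requirements, the cell below prepares a made-up three-row data frame in base R. The data frame `toy` and its columns are hypothetical and only show the conversions; they are not part of the adult dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy data: a label, one numeric feature and one categorical feature.\n",
    "toy <- data.frame(label = c(1L, -1L, 1L),\n",
    "                  age = c(28L, 73L, 41L),\n",
    "                  job = c(\"Sales\", \"Tech\", \"Sales\"),\n",
    "                  stringsAsFactors = FALSE)\n",
    "\n",
    "toy$label <- as.numeric(toy$label)  # target as numeric\n",
    "toy$age <- as.numeric(toy$age)      # numeric feature as numeric\n",
    "toy$job <- as.factor(toy$job)       # categorical feature as factor\n",
    "\n",
    "# Inspect the resulting column classes.\n",
    "print(sapply(toy, class))"
   ]
  },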
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<tbody>\n",
       "\t<tr><td>1            </td><td>1            </td><td>28           </td><td>3.891958e+36 </td><td>120135       </td><td>-1.040168e-34</td><td>11           </td><td>1.261457e+32 </td><td>-371032621056</td><td>8.078709e-34 </td><td>-9.782155e+30</td><td>-9.047987e-38</td><td>0            </td><td>0            </td><td>40           </td><td>1.219625e+24 </td></tr>\n",
       "</tbody>\n",
       "</table>\n"
      ],
      "text/latex": [
       "\\begin{tabular}{llllllllllllllll}\n",
       "\t 1             & 1             & 28            & 3.891958e+36  & 120135        & -1.040168e-34 & 11            & 1.261457e+32  & -371032621056 & 8.078709e-34  & -9.782155e+30 & -9.047987e-38 & 0             & 0             & 40            & 1.219625e+24 \\\\\n",
       "\\end{tabular}\n"
      ],
      "text/markdown": [
       "\n",
       "| 1             | 1             | 28            | 3.891958e+36  | 120135        | -1.040168e-34 | 11            | 1.261457e+32  | -371032621056 | 8.078709e-34  | -9.782155e+30 | -9.047987e-38 | 0             | 0             | 40            | 1.219625e+24  | \n",
       "\n",
       "\n"
      ],
      "text/plain": [
       "     [,1] [,2] [,3] [,4]         [,5]   [,6]          [,7] [,8]        \n",
       "[1,] 1    1    28   3.891958e+36 120135 -1.040168e-34 11   1.261457e+32\n",
       "     [,9]          [,10]        [,11]         [,12]         [,13] [,14] [,15]\n",
       "[1,] -371032621056 8.078709e-34 -9.782155e+30 -9.047987e-38 0     0     40   \n",
       "     [,16]       \n",
       "[1,] 1.219625e+24"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<tbody>\n",
       "\t<tr><td>1            </td><td>1            </td><td>73           </td><td>-1.220011e+15</td><td>30958        </td><td>-40904704    </td><td>10           </td><td>-2.326434e-34</td><td>-371032621056</td><td>9.094553e-37 </td><td>-9.782155e+30</td><td>-3.163861e-08</td><td>0            </td><td>0            </td><td>25           </td><td>1.219625e+24 </td></tr>\n",
       "</tbody>\n",
       "</table>\n"
      ],
      "text/latex": [
       "\\begin{tabular}{llllllllllllllll}\n",
       "\t 1             & 1             & 73            & -1.220011e+15 & 30958         & -40904704     & 10            & -2.326434e-34 & -371032621056 & 9.094553e-37  & -9.782155e+30 & -3.163861e-08 & 0             & 0             & 25            & 1.219625e+24 \\\\\n",
       "\\end{tabular}\n"
      ],
      "text/markdown": [
       "\n",
       "| 1             | 1             | 73            | -1.220011e+15 | 30958         | -40904704     | 10            | -2.326434e-34 | -371032621056 | 9.094553e-37  | -9.782155e+30 | -3.163861e-08 | 0             | 0             | 25            | 1.219625e+24  | \n",
       "\n",
       "\n"
      ],
      "text/plain": [
       "     [,1] [,2] [,3] [,4]          [,5]  [,6]      [,7] [,8]         \n",
       "[1,] 1    1    73   -1.220011e+15 30958 -40904704 10   -2.326434e-34\n",
       "     [,9]          [,10]        [,11]         [,12]         [,13] [,14] [,15]\n",
       "[1,] -371032621056 9.094553e-37 -9.782155e+30 -3.163861e-08 0     0     25   \n",
       "     [,16]       \n",
       "[1,] 1.219625e+24"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "train_path = system.file(\"extdata\", \"adult_train.1000\", package=\"catboost\")\n",
    "test_path = system.file(\"extdata\", \"adult_test.1000\", package=\"catboost\")\n",
    "\n",
    "column_description_vector = rep('numeric', 15)\n",
    "cat_features <- c(3, 5, 7, 8, 9, 10, 11, 15)\n",
    "for (i in cat_features)\n",
    "    column_description_vector[i] <- 'factor'\n",
    "    \n",
    "train <- read.table(train_path, header = FALSE, sep = \"\\t\", colClasses = column_description_vector, na.strings = 'NAN')\n",
    "test <- read.table(test_path, header = FALSE, sep = \"\\t\", colClasses = column_description_vector, na.strings = 'NAN')\n",
    "target <- c(1)\n",
    "train_pool <- catboost.load_pool(data=train[,-target], label = train[,target])\n",
    "test_pool <- catboost.load_pool(data=test[,-target], label = test[,target])\n",
    "head(train_pool, 1)\n",
    "head(test_pool, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Explore pool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Nrows:  1000 , Ncols:  14 \n",
      "\n",
      "First row: "
     ]
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<tbody>\n",
       "\t<tr><td>1            </td><td>1            </td><td>28           </td><td>3.891958e+36 </td><td>120135       </td><td>-1.040168e-34</td><td>11           </td><td>1.261457e+32 </td><td>-371032621056</td><td>8.078709e-34 </td><td>-9.782155e+30</td><td>-9.047987e-38</td><td>0            </td><td>0            </td><td>40           </td><td>1.219625e+24 </td></tr>\n",
       "</tbody>\n",
       "</table>\n"
      ],
      "text/latex": [
       "\\begin{tabular}{llllllllllllllll}\n",
       "\t 1             & 1             & 28            & 3.891958e+36  & 120135        & -1.040168e-34 & 11            & 1.261457e+32  & -371032621056 & 8.078709e-34  & -9.782155e+30 & -9.047987e-38 & 0             & 0             & 40            & 1.219625e+24 \\\\\n",
       "\\end{tabular}\n"
      ],
      "text/markdown": [
       "\n",
       "| 1             | 1             | 28            | 3.891958e+36  | 120135        | -1.040168e-34 | 11            | 1.261457e+32  | -371032621056 | 8.078709e-34  | -9.782155e+30 | -9.047987e-38 | 0             | 0             | 40            | 1.219625e+24  | \n",
       "\n",
       "\n"
      ],
      "text/plain": [
       "     [,1] [,2] [,3] [,4]         [,5]   [,6]          [,7] [,8]        \n",
       "[1,] 1    1    28   3.891958e+36 120135 -1.040168e-34 11   1.261457e+32\n",
       "     [,9]          [,10]        [,11]         [,12]         [,13] [,14] [,15]\n",
       "[1,] -371032621056 8.078709e-34 -9.782155e+30 -9.047987e-38 0     0     40   \n",
       "     [,16]       \n",
       "[1,] 1.219625e+24"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Last row: "
     ]
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<tbody>\n",
       "\t<tr><td>-1           </td><td>1            </td><td>71           </td><td>-1.816107e-18</td><td>177906       </td><td>5.92781e-19  </td><td>13           </td><td>-2.326434e-34</td><td>-1.816107e-18</td><td>9.094553e-37 </td><td>-9.782155e+30</td><td>-3.163861e-08</td><td>0            </td><td>0            </td><td>10           </td><td>1.219625e+24 </td></tr>\n",
       "</tbody>\n",
       "</table>\n"
      ],
      "text/latex": [
       "\\begin{tabular}{llllllllllllllll}\n",
       "\t -1            & 1             & 71            & -1.816107e-18 & 177906        & 5.92781e-19   & 13            & -2.326434e-34 & -1.816107e-18 & 9.094553e-37  & -9.782155e+30 & -3.163861e-08 & 0             & 0             & 10            & 1.219625e+24 \\\\\n",
       "\\end{tabular}\n"
      ],
      "text/markdown": [
       "\n",
       "| -1            | 1             | 71            | -1.816107e-18 | 177906        | 5.92781e-19   | 13            | -2.326434e-34 | -1.816107e-18 | 9.094553e-37  | -9.782155e+30 | -3.163861e-08 | 0             | 0             | 10            | 1.219625e+24  | \n",
       "\n",
       "\n"
      ],
      "text/plain": [
       "     [,1] [,2] [,3] [,4]          [,5]   [,6]        [,7] [,8]         \n",
       "[1,] -1   1    71   -1.816107e-18 177906 5.92781e-19 13   -2.326434e-34\n",
       "     [,9]          [,10]        [,11]         [,12]         [,13] [,14] [,15]\n",
       "[1,] -1.816107e-18 9.094553e-37 -9.782155e+30 -3.163861e-08 0     0     10   \n",
       "     [,16]       \n",
       "[1,] 1.219625e+24"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Column names: "
     ]
    },
    {
     "data": {
      "text/html": [
       "<ol class=list-inline>\n",
       "\t<li>'V2'</li>\n",
       "\t<li>'V3'</li>\n",
       "\t<li>'V4'</li>\n",
       "\t<li>'V5'</li>\n",
       "\t<li>'V6'</li>\n",
       "\t<li>'V7'</li>\n",
       "\t<li>'V8'</li>\n",
       "\t<li>'V9'</li>\n",
       "\t<li>'V10'</li>\n",
       "\t<li>'V11'</li>\n",
       "\t<li>'V12'</li>\n",
       "\t<li>'V13'</li>\n",
       "\t<li>'V14'</li>\n",
       "\t<li>'V15'</li>\n",
       "</ol>\n"
      ],
      "text/latex": [
       "\\begin{enumerate*}\n",
       "\\item 'V2'\n",
       "\\item 'V3'\n",
       "\\item 'V4'\n",
       "\\item 'V5'\n",
       "\\item 'V6'\n",
       "\\item 'V7'\n",
       "\\item 'V8'\n",
       "\\item 'V9'\n",
       "\\item 'V10'\n",
       "\\item 'V11'\n",
       "\\item 'V12'\n",
       "\\item 'V13'\n",
       "\\item 'V14'\n",
       "\\item 'V15'\n",
       "\\end{enumerate*}\n"
      ],
      "text/markdown": [
       "1. 'V2'\n",
       "2. 'V3'\n",
       "3. 'V4'\n",
       "4. 'V5'\n",
       "5. 'V6'\n",
       "6. 'V7'\n",
       "7. 'V8'\n",
       "8. 'V9'\n",
       "9. 'V10'\n",
       "10. 'V11'\n",
       "11. 'V12'\n",
       "12. 'V13'\n",
       "13. 'V14'\n",
       "14. 'V15'\n",
       "\n",
       "\n"
      ],
      "text/plain": [
       " [1] \"V2\"  \"V3\"  \"V4\"  \"V5\"  \"V6\"  \"V7\"  \"V8\"  \"V9\"  \"V10\" \"V11\" \"V12\" \"V13\"\n",
       "[13] \"V14\" \"V15\""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# number of rows and columns\n",
    "cat(\"Nrows: \", nrow(train_pool), \", Ncols: \", ncol(train_pool), \"\\n\")\n",
    "# first rows of pool\n",
    "cat(\"\\nFirst row: \")\n",
    "head(train_pool, n = 1)\n",
    "cat(\"\\nLast row: \")\n",
    "tail(train_pool, n = 1)\n",
    "cat(\"\\nColumn names: \")\n",
    "colnames(train_pool)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train model\n",
    "\n",
    "See **`help(catboost.train)`** for the full list of arguments and their descriptions. Available loss functions: RMSE, MAE, Logloss, CrossEntropy, Quantile, LogLinQuantile, Poisson, MAPE, SMAPE, MultiClass, AUC."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_params <- list(iterations = 100,\n",
    "                   thread_count = 10,\n",
    "                   loss_function = 'Logloss',\n",
    "                   ignored_features = c(4,9),\n",
    "                   border_count = 32,\n",
    "                   depth = 5,\n",
    "                   learning_rate = 0.03,\n",
    "                   l2_leaf_reg = 3.5,\n",
    "                   train_dir = 'train_dir',\n",
    "                   logging_level = 'Silent')\n",
    "model <- catboost.train(train_pool, test_pool, fit_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict and evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample predictions:  0.4286844 0.2480608 0.5529215 0.1709756 0.032663 \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "      \n",
       "labels  -1   1\n",
       "     0 414 102\n",
       "     1  86 398"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Accuracy:  0.812 \n",
      "\n",
      "Feature importances \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<dl class=dl-horizontal>\n",
       "\t<dt>V2</dt>\n",
       "\t\t<dd>7.04207435454357</dd>\n",
       "\t<dt>V3</dt>\n",
       "\t\t<dd>0.708119813095853</dd>\n",
       "\t<dt>V4</dt>\n",
       "\t\t<dd>0.832378293966244</dd>\n",
       "\t<dt>V5</dt>\n",
       "\t\t<dd>16.108506622608</dd>\n",
       "\t<dt>V6</dt>\n",
       "\t\t<dd>0</dd>\n",
       "\t<dt>V7</dt>\n",
       "\t\t<dd>24.6381593529844</dd>\n",
       "\t<dt>V8</dt>\n",
       "\t\t<dd>12.0431579359021</dd>\n",
       "\t<dt>V9</dt>\n",
       "\t\t<dd>5.98196334086858</dd>\n",
       "\t<dt>V10</dt>\n",
       "\t\t<dd>1.55807451568542</dd>\n",
       "\t<dt>V11</dt>\n",
       "\t\t<dd>0</dd>\n",
       "\t<dt>V12</dt>\n",
       "\t\t<dd>23.3114562755089</dd>\n",
       "\t<dt>V13</dt>\n",
       "\t\t<dd>0.76877335254996</dd>\n",
       "\t<dt>V14</dt>\n",
       "\t\t<dd>6.07231106027566</dd>\n",
       "\t<dt>V15</dt>\n",
       "\t\t<dd>0.935025082011272</dd>\n",
       "</dl>\n"
      ],
      "text/latex": [
       "\\begin{description*}\n",
       "\\item[V2] 7.04207435454357\n",
       "\\item[V3] 0.708119813095853\n",
       "\\item[V4] 0.832378293966244\n",
       "\\item[V5] 16.108506622608\n",
       "\\item[V6] 0\n",
       "\\item[V7] 24.6381593529844\n",
       "\\item[V8] 12.0431579359021\n",
       "\\item[V9] 5.98196334086858\n",
       "\\item[V10] 1.55807451568542\n",
       "\\item[V11] 0\n",
       "\\item[V12] 23.3114562755089\n",
       "\\item[V13] 0.76877335254996\n",
       "\\item[V14] 6.07231106027566\n",
       "\\item[V15] 0.935025082011272\n",
       "\\end{description*}\n"
      ],
      "text/markdown": [
       "V2\n",
       ":   7.04207435454357V3\n",
       ":   0.708119813095853V4\n",
       ":   0.832378293966244V5\n",
       ":   16.108506622608V6\n",
       ":   0V7\n",
       ":   24.6381593529844V8\n",
       ":   12.0431579359021V9\n",
       ":   5.98196334086858V10\n",
       ":   1.55807451568542V11\n",
       ":   0V12\n",
       ":   23.3114562755089V13\n",
       ":   0.76877335254996V14\n",
       ":   6.07231106027566V15\n",
       ":   0.935025082011272\n",
       "\n"
      ],
      "text/plain": [
       "        V2         V3         V4         V5         V6         V7         V8 \n",
       " 7.0420744  0.7081198  0.8323783 16.1085066  0.0000000 24.6381594 12.0431579 \n",
       "        V9        V10        V11        V12        V13        V14        V15 \n",
       " 5.9819633  1.5580745  0.0000000 23.3114563  0.7687734  6.0723111  0.9350251 "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Tree count:  100 \n"
     ]
    }
   ],
   "source": [
    "calc_accuracy <- function(prediction, expected) {\n",
    "  labels <- ifelse(prediction > 0.5, 1, -1)\n",
    "  accuracy <- sum(labels == expected) / length(labels)\n",
    "  return(accuracy)\n",
    "}\n",
    "\n",
    "prediction <- catboost.predict(model, test_pool, prediction_type = 'Probability')\n",
    "cat(\"Sample predictions: \", sample(prediction, 5), \"\\n\")\n",
    "\n",
    "labels <- catboost.predict(model, test_pool, prediction_type = 'Class')\n",
    "table(labels, test[,target])\n",
    "\n",
    "# works properly only for Logloss\n",
    "accuracy <- calc_accuracy(prediction, test[,target])\n",
    "cat(\"\\nAccuracy: \", accuracy, \"\\n\")\n",
    "\n",
    "# feature splits importances (not finished)\n",
    "\n",
    "cat(\"\\nFeature importances\", \"\\n\")\n",
    "catboost.get_feature_importance(model, train_pool)\n",
    "\n",
    "cat(\"\\nTree count: \", model$tree_count, \"\\n\")"
   ]
  },
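  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reported accuracy can be cross-checked against the confusion table printed above. The counts in the cell below are copied from that output. Pairing the predicted class 0 with the true label -1 is an assumption about how the `Class` predictions line up with the -1/1 targets, and the variable names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Counts copied from the confusion table in the output above (filled column by column).\n",
    "confusion <- matrix(c(414, 86, 102, 398), nrow = 2,\n",
    "                    dimnames = list(predicted = c(\"0\", \"1\"),\n",
    "                                    actual = c(\"-1\", \"1\")))\n",
    "\n",
    "# Correct predictions sit on the diagonal (0 with -1, 1 with 1).\n",
    "accuracy_from_table <- sum(diag(confusion)) / sum(confusion)\n",
    "print(accuracy_from_table)"
   ]
  },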
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
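    "You can also use the **`catboost.staged_predict`** function, which returns an iterator over predictions computed on successively larger ranges of trees."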
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TRUE \n",
      "TRUE"
     ]
    }
   ],
   "source": [
    "library(iterators)\n",
    "staged_predictions <- catboost.staged_predict(model, test_pool, ntree_start = 2, ntree_end = 5,\n",
    "                                              eval_period = 2, prediction_type = 'Probability')\n",
    "staged_prediction_2_4 = nextElem(staged_predictions) # 2nd and 3rd trees\n",
    "staged_prediction_2_5 = nextElem(staged_predictions) # 2nd, 3rd and 4th trees\n",
    "\n",
    "prediction_2_4 = catboost.predict(model, test_pool, ntree_start = 2, ntree_end = 4, prediction_type = 'Probability')\n",
    "prediction_2_5 = catboost.predict(model, test_pool, ntree_start = 2, ntree_end = 5, prediction_type = 'Probability')\n",
    "cat(all(prediction_2_4 == staged_prediction_2_4), '\\n')\n",
    "cat(all(prediction_2_5 == staged_prediction_2_5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Useful features\n",
    "\n",
    "If you have a validation set, you can use the overfitting detector to stop training early and reduce training time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Simple model tree count:  500 \n",
      "Model with od tree count:  268 \n"
     ]
    }
   ],
   "source": [
    "params_simple <- list(iterations = 500,\n",
    "                      loss_function = 'Logloss',\n",
    "                      train_dir = 'train_dir',\n",
    "                      logging_level = 'Silent')\n",
    "model_simple <- catboost.train(train_pool, test_pool, params_simple)\n",
    "\n",
    "params_with_od <- list(iterations = 500,\n",
    "                       loss_function = 'Logloss',\n",
    "                       train_dir = 'train_dir',\n",
    "                       od_type = 'Iter',\n",
    "                       od_wait = 30,\n",
    "                       logging_level = 'Silent')\n",
    "model_with_od <- catboost.train(train_pool, test_pool, params_with_od)\n",
    "\n",
    "cat('Simple model tree count: ', model_simple$tree_count, '\\n')\n",
    "cat('Model with od tree count: ', model_with_od$tree_count, '\\n')"
   ]
  },
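  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate what an iteration-based overfitting detector does, here is a simplified sketch in base R. It is not CatBoost's implementation: the validation losses are made up, and the hypothetical helper `stop_iteration` simply stops once the loss has not improved for `wait` consecutive iterations, which is the idea behind `od_type = 'Iter'` with `od_wait`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper: return the iteration at which training would stop.\n",
    "stop_iteration <- function(val_loss, wait) {\n",
    "  best <- Inf\n",
    "  best_iter <- 0\n",
    "  for (i in seq_along(val_loss)) {\n",
    "    if (val_loss[i] < best) {\n",
    "      best <- val_loss[i]\n",
    "      best_iter <- i\n",
    "    } else if (i - best_iter >= wait) {\n",
    "      return(i)  # no improvement for `wait` iterations\n",
    "    }\n",
    "  }\n",
    "  length(val_loss)  # the detector never triggered\n",
    "}\n",
    "\n",
    "# Made-up validation losses: improving at first, then getting worse.\n",
    "val_loss <- c(0.60, 0.50, 0.45, 0.44, 0.46, 0.47, 0.48, 0.49)\n",
    "print(stop_iteration(val_loss, wait = 3))"
   ]
  },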
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also set `use_best_model = TRUE` to keep the model from the iteration that scored best on the validation set, and use it for predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Simple model accuracy:  0.808 \n",
      "The best model accuracy:  0.822 \n"
     ]
    }
   ],
   "source": [
    "params_simple <- list(iterations = 1000,\n",
    "                      loss_function = 'Logloss',\n",
    "                      train_dir = 'train_dir',\n",
    "                      logging_level = 'Silent')\n",
    "model_simple <- catboost.train(train_pool, test_pool, params_simple)\n",
    "\n",
    "params_best <- list(iterations = 1000,\n",
    "                    loss_function = 'Logloss',\n",
    "                    train_dir = 'train_dir',\n",
    "                    use_best_model = TRUE,\n",
    "                    logging_level = 'Silent')\n",
    "model_best <- catboost.train(train_pool, test_pool, params_best)\n",
    "\n",
    "prediction_simple <- catboost.predict(model_simple, test_pool, prediction_type = 'Probability')\n",
    "prediction_best <- catboost.predict(model_best, test_pool, prediction_type = 'Probability')\n",
    "\n",
    "cat('Simple model accuracy: ', calc_accuracy(prediction_simple, test[,target]), '\\n')\n",
    "cat('The best model accuracy: ', calc_accuracy(prediction_best, test[,target]), '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using CatBoost with caret"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load and preprocess the Titanic dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "set.seed(12345)\n",
    "\n",
    "data <- as.data.frame(as.matrix(titanic_train), stringsAsFactors=TRUE)\n",
    "\n",
    "age_levels <- levels(data$Age)\n",
    "most_frequent_age <- which.max(table(data$Age))\n",
    "data$Age[is.na(data$Age)] <- age_levels[most_frequent_age]\n",
    "\n",
    "drop_columns = c(\"PassengerId\", \"Survived\", \"Name\", \"Ticket\", \"Cabin\")\n",
    "x <- data[,!(names(data) %in% drop_columns)]\n",
    "y <- data[,c(\"Survived\")]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We train with 5-fold cross-validation and search over several tree depths to find the best one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_control <- trainControl(method = \"cv\",\n",
    "                            number = 5,\n",
    "                            classProbs = TRUE)\n",
    "\n",
    "grid <- expand.grid(depth = c(4, 6, 8),\n",
    "                    learning_rate = 0.1,\n",
    "                    iterations = 100,\n",
    "                    l2_leaf_reg = 0.1,\n",
    "                    rsm = 0.95,\n",
    "                    border_count = 64)\n",
    "\n",
    "model <- train(x, as.factor(make.names(y)),\n",
    "                method = catboost.caret,\n",
    "                logging_level = 'Silent', preProc = NULL,\n",
    "                tuneGrid = grid, trControl = fit_control)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Print information about the model and its variable importances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Catboost \n",
      "\n",
      "891 samples\n",
      "  7 predictor\n",
      "  2 classes: 'X0', 'X1' \n",
      "\n",
      "No pre-processing\n",
      "Resampling: Cross-Validated (5 fold) \n",
      "Summary of sample sizes: 714, 712, 713, 713, 712 \n",
      "Resampling results across tuning parameters:\n",
      "\n",
      "  depth  Accuracy   Kappa    \n",
      "  4      0.8147803  0.5921606\n",
      "  6      0.8035946  0.5740471\n",
      "  8      0.8136628  0.5961711\n",
      "\n",
      "Tuning parameter 'learning_rate' was held constant at a value of 0.1\n",
      "\n",
      "Tuning parameter 'rsm' was held constant at a value of 0.95\n",
      "Tuning\n",
      " parameter 'border_count' was held constant at a value of 64\n",
      "Accuracy was used to select the optimal model using  the largest value.\n",
      "The final values used for the model were depth = 4, learning_rate =\n",
      " 0.1, iterations = 100, l2_leaf_reg = 0.1, rsm = 0.95 and border_count = 64.\n",
      "custom variable importance\n",
      "\n",
      "         Overall\n",
      "Sex       26.832\n",
      "Fare      22.384\n",
      "Pclass    16.507\n",
      "Parch     14.540\n",
      "Age        7.522\n",
      "Embarked   7.473\n",
      "SibSp      4.741\n"
     ]
    }
   ],
   "source": [
    "print(model)\n",
    "\n",
    "importance <- varImp(model, scale = FALSE)\n",
    "print(importance)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, predict class probabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<thead><tr><th scope=col>X0</th><th scope=col>X1</th></tr></thead>\n",
       "<tbody>\n",
       "\t<tr><td>0.93941009</td><td>0.06058991</td></tr>\n",
       "\t<tr><td>0.01378653</td><td>0.98621347</td></tr>\n",
       "\t<tr><td>0.29743151</td><td>0.70256849</td></tr>\n",
       "\t<tr><td>0.02529839</td><td>0.97470161</td></tr>\n",
       "\t<tr><td>0.95701379</td><td>0.04298621</td></tr>\n",
       "\t<tr><td>0.91183457</td><td>0.08816543</td></tr>\n",
       "</tbody>\n",
       "</table>\n"
      ],
      "text/latex": [
       "\\begin{tabular}{r|ll}\n",
       " X0 & X1\\\\\n",
       "\\hline\n",
       "\t 0.93941009 & 0.06058991\\\\\n",
       "\t 0.01378653 & 0.98621347\\\\\n",
       "\t 0.29743151 & 0.70256849\\\\\n",
       "\t 0.02529839 & 0.97470161\\\\\n",
       "\t 0.95701379 & 0.04298621\\\\\n",
       "\t 0.91183457 & 0.08816543\\\\\n",
       "\\end{tabular}\n"
      ],
      "text/markdown": [
       "\n",
       "X0 | X1 | \n",
       "|---|---|---|---|---|---|\n",
       "| 0.93941009 | 0.06058991 | \n",
       "| 0.01378653 | 0.98621347 | \n",
       "| 0.29743151 | 0.70256849 | \n",
       "| 0.02529839 | 0.97470161 | \n",
       "| 0.95701379 | 0.04298621 | \n",
       "| 0.91183457 | 0.08816543 | \n",
       "\n",
       "\n"
      ],
      "text/plain": [
       "  X0         X1        \n",
       "1 0.93941009 0.06058991\n",
       "2 0.01378653 0.98621347\n",
       "3 0.29743151 0.70256849\n",
       "4 0.02529839 0.97470161\n",
       "5 0.95701379 0.04298621\n",
       "6 0.91183457 0.08816543"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "head(predict(model, type = 'prob'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "R",
   "language": "R",
   "name": "ir"
  },
  "language_info": {
   "codemirror_mode": "r",
   "file_extension": ".r",
   "mimetype": "text/x-r-source",
   "name": "R",
   "pygments_lexer": "r",
   "version": "3.4.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
